"""Set of reach target tasks."""
from abc import ABC

import numpy as np
from gymnasium import spaces
from mojo import Mojo
from mojo.elements import Body, Geom
from mojo.elements.consts import GeomType

from bigym.bigym_env import BiGymEnv, MAX_DISTANCE_FROM_TARGET
from bigym.const import HandSide
from bigym.utils.physics_utils import distance


class Target:
    """Non-collidable, massless sphere marker attached to its own body."""

    def __init__(self, mojo: Mojo, size: np.ndarray, color: np.ndarray):
        """Create the target body and attach the sphere geom to it.

        :param mojo: Mojo instance in which the body and geom are created.
        :param size: Size forwarded to the sphere geom.
        :param color: Colour array forwarded to the sphere geom.
        """
        marker_body = Body.create(mojo)
        marker_geom: Geom = Geom.create(
            mojo,
            parent=marker_body,
            geom_type=GeomType.SPHERE,
            size=size,
            color=color,
            mass=0,
        )
        # The target is a visual marker only; hands must be able to pass
        # through it without physical contact.
        marker_geom.set_collidable(False)
        self.body = marker_body
        self.geom: Geom = marker_geom


class _ReachTargetEnv(BiGymEnv, ABC):
    """Base reach target environment."""

    # Size array forwarded to each target's sphere geom.
    TARGET_SIZE = np.array([0.05, 0.05, 0.05])
    # Default target colour (used by the single-target tasks).
    TARGET_COLOR = np.array([1, 0, 0, 1])
    # Per-axis half-width of the uniform random offset applied on reset.
    TARGET_BOUNDS = np.array([0.1, 0.1, 0.1])

    # A hand within this distance of a target counts as having reached it.
    TOLERANCE = 0.1

    def _dist_to_target(self, target: Target, side: HandSide) -> float:
        """Return the Euclidean distance from the ``side`` hand to ``target``."""
        target_pos = target.body.get_position()
        hand_pos = self._robot.get_hand_pos(side)
        offset = target_pos - hand_pos
        return float(np.linalg.norm(offset))


class ReachTarget(_ReachTargetEnv):
    """Reach a single randomly placed target with either hand."""

    # Nominal target position; a uniform offset within +/- TARGET_BOUNDS is
    # added to it on every reset.
    TARGET_POSITION = np.array([0.5, 0, 1])

    def _initialize_env(self):
        self.target = Target(self._mojo, self.TARGET_SIZE, self.TARGET_COLOR)

    def _get_task_privileged_obs_space(self):
        position_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32
        )
        return {"target_position": position_space}

    def _get_task_privileged_obs(self):
        # np.array copies by default, so the observation is a fresh float32
        # array decoupled from whatever get_position() returned.
        position = np.array(self.target.body.get_position(), np.float32)
        return {"target_position": position}

    def _success(self) -> bool:
        # Either hand counts; the left hand is checked first and the right
        # hand only if the left one is out of tolerance.
        return any(
            self._dist_to_target(self.target, hand) <= self.TOLERANCE
            for hand in (HandSide.LEFT, HandSide.RIGHT)
        )

    def _fail(self) -> bool:
        return distance(self._robot.pelvis, self.target.body) > MAX_DISTANCE_FROM_TARGET

    def _on_reset(self):
        # NOTE(review): draws from numpy's global RNG rather than an
        # env-seeded generator - confirm this is intended for reproducibility.
        jitter = np.random.uniform(-self.TARGET_BOUNDS, self.TARGET_BOUNDS)
        self.target.body.set_position(self.TARGET_POSITION + jitter)


class ReachTargetSingle(ReachTarget):
    """Reach the target with one specific hand, selected by ``SIDE``.

    Target creation, observations, failure and reset logic are inherited
    unchanged from ``ReachTarget``.
    """

    # Hand that must reach the target. It is read through ``self.SIDE``, so a
    # subclass can override it to require the other hand.
    SIDE = HandSide.LEFT

    def _success(self) -> bool:
        # Unlike ReachTarget, only the hand named by SIDE counts.
        return self._dist_to_target(self.target, self.SIDE) <= self.TOLERANCE


class ReachTargetDual(_ReachTargetEnv):
    """Reach two targets: the left one with the left hand, the right one with the right hand."""

    # Target colours while unreached / while their hand is within tolerance.
    COLOR_LEFT = np.array([0.6, 0, 0, 1])
    COLOR_LEFT_SUCCESS = np.array([1, 0.5, 0.2, 1])
    COLOR_RIGHT = np.array([0, 0.6, 0, 1])
    COLOR_RIGHT_SUCCESS = np.array([0.2, 0.3, 1, 1])

    # Nominal target positions; each gets an independent uniform offset
    # within +/- TARGET_BOUNDS on reset.
    TARGET_POSITION_LEFT = np.array([0.5, 0.2, 1])
    TARGET_POSITION_RIGHT = np.array([0.5, -0.2, 1])

    def _initialize_env(self):
        self.target_left = Target(self._mojo, self.TARGET_SIZE, self.COLOR_LEFT)
        self.target_right = Target(self._mojo, self.TARGET_SIZE, self.COLOR_RIGHT)

    def _get_task_privileged_obs_space(self):
        return {
            "target_position_left": spaces.Box(
                low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32
            ),
            "target_position_right": spaces.Box(
                low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32
            ),
        }

    def _get_task_privileged_obs(self):
        # np.array already returns a fresh float32 copy; no extra .copy() needed.
        return {
            "target_position_left": np.array(
                self.target_left.body.get_position(), np.float32
            ),
            "target_position_right": np.array(
                self.target_right.body.get_position(), np.float32
            ),
        }

    def _target_reached(self, target: Target, side: HandSide) -> bool:
        """Return True if the ``side`` hand is within TOLERANCE of ``target``."""
        return self._dist_to_target(target, side) <= self.TOLERANCE

    def _success(self) -> bool:
        # Both hands must be on their own target at the same time.
        return self._target_reached(
            self.target_left, HandSide.LEFT
        ) and self._target_reached(self.target_right, HandSide.RIGHT)

    def _fail(self) -> bool:
        # Fail as soon as the pelvis is too far from either target; checking
        # only one of them would miss the robot drifting away from the other.
        return any(
            distance(self._robot.pelvis, target.body) > MAX_DISTANCE_FROM_TARGET
            for target in (self.target_left, self.target_right)
        )

    def _on_reset(self):
        # NOTE(review): uses numpy's global RNG rather than an env-seeded
        # generator - confirm this is intended for reproducibility.
        offset = np.random.uniform(-self.TARGET_BOUNDS, self.TARGET_BOUNDS)
        self.target_left.body.set_position(self.TARGET_POSITION_LEFT + offset)
        offset = np.random.uniform(-self.TARGET_BOUNDS, self.TARGET_BOUNDS)
        self.target_right.body.set_position(self.TARGET_POSITION_RIGHT + offset)

    def _on_step(self):
        # Visual feedback: each target switches to its "success" colour while
        # its assigned hand is within tolerance.
        left_reached = self._target_reached(self.target_left, HandSide.LEFT)
        self.target_left.geom.set_color(
            self.COLOR_LEFT_SUCCESS if left_reached else self.COLOR_LEFT
        )

        right_reached = self._target_reached(self.target_right, HandSide.RIGHT)
        self.target_right.geom.set_color(
            self.COLOR_RIGHT_SUCCESS if right_reached else self.COLOR_RIGHT
        )
